import os
import re
import sys
import base64
import subprocess
from pathlib import Path
from threading import Timer
import openai
import inference

from openai import OpenAI
# OpenAI client built from the OPENAI_API_KEY environment variable.
# NOTE(review): `client` is not referenced anywhere in this file (model calls go
# through `inference.get_model_response`); confirm whether it is still needed.
client = OpenAI(
  api_key=os.getenv("OPENAI_API_KEY")
)

# Model identifier passed to inference.get_model_response by every LLM call below.
model = "GPT4o"

# Sampling configuration used by get_1shot_diagram_encoding_ini.
PARAMS = {
    "temperature": 0.7,  # sampling temperature forwarded to generate_diagram_encoding
    "num_samples": 8,  # stop sampling once this many encodings passed verification
    "max_attempts": 20,  # NOTE(review): not read anywhere in this file
    "mem_capacity": 3,  # NOTE(review): not read anywhere in this file
    "outer_loops": 6,  # N: number of outer loops
    "inner_loops": 3   # M: number of inner loops (upper bound on total attempts = N*M)
}

def read_file(path):
    """Return the entire contents of the text file at *path*, decoded as UTF-8."""
    return Path(path).read_text(encoding='utf-8')

def write_file(path, content):
    """Write *content* to *path* as UTF-8 text, creating parent directories.

    Any existing file is overwritten. A bare filename (no directory
    component) is written relative to the current working directory.
    """
    parent = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so only create a parent when
    # the path actually has a directory component.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, 'w', encoding='utf-8') as f:
        f.write(content)

def load_domain_info(domain_name):
    """Load the problem description and initial state for *domain_name*.

    Both files live in a directory named after the domain:
    ``<domain_name>/<domain_name>_domain.txt`` holds the problem description
    and ``<domain_name>/initial_state.txt`` holds the initial state.

    Returns:
        A ``(problem_description, initial_state)`` tuple of file contents.
    """
    description_path = os.path.join(domain_name, f"{domain_name}_domain.txt")
    state_path = os.path.join(domain_name, "initial_state.txt")
    return read_file(description_path), read_file(state_path)

def verify_diagram_encoding(problem_description, diagram_encoding, initial_state, model):
    """Ask *model* whether *diagram_encoding* faithfully depicts the initial state.

    The prompt asks for a step-by-step check of every item in the encoding,
    then a final answer in a ```yes_no``` fenced block and, on "no", a short
    reason in an ```error``` fenced block.

    Returns:
        ``(validity, error_description)``: whatever ``inference.extract_content``
        yields for the ``yes_no`` and ``error`` blocks, respectively.
        NOTE: if these are raw strings, a "no" answer is still truthy, so
        callers should parse the value rather than test it with ``if``.
    """
    verification_prompt = f"""
    Consider the following problem description:
    {problem_description}
    The following is a description of the initial state:
    {initial_state}
    The diagram encoding (which is made of a description of the relative/absolute position, relative size, status (includes color if relevant), and a text identifier for each object in the scene) of the initial state is:
    {diagram_encoding}
    Your task is to check whether this diagram encoding is correct and encodes an intuitive visual representation of the scene (information about cost of actions can be ignored). Note that there should be at least one statement (a statement inlcudes a text identifier, position, size, etc) for each object/location in the scene. For every item in the diagram encoding, write a description about wether the object description matches the state description and the relative position/size of the object matches the initial state description. Think step by step. The final answer must be a yes or a no. If any of the descriptions is wrong, say no, yes otherwise, in the following format.
    ```yes_no
    <final answer, yes or no>
    ```
    If your final answer is 'no', provide a very short phrase describing the error:
    ```error
    <error description>
    ```
    """
    reply = inference.get_model_response([verification_prompt], model)
    return (
        inference.extract_content(reply, "yes_no"),
        inference.extract_content(reply, "error"),
    )


def _parse_ranking_ids(ranking, num_candidates):
    """Extract unique, in-range 1-based candidate ids from a ranking answer.

    *ranking* is what ``inference.extract_content`` returned for the
    ```ranking``` block. A string such as "3, 1, 2" or an iterable of such
    strings is accepted, and ``None`` yields no ids. Whole numbers are matched,
    so multi-digit ids like "10" stay intact. Repeated and out-of-range ids are
    dropped, and the first occurrence order is kept.
    """
    if ranking is None:
        return []
    if isinstance(ranking, str):
        text = ranking
    else:
        text = " ".join(str(part) for part in ranking)

    ranked_ids = []
    seen = set()
    for token in re.findall(r"\d+", text):
        candidate_id = int(token)
        if 1 <= candidate_id <= num_candidates and candidate_id not in seen:
            seen.add(candidate_id)
            ranked_ids.append(candidate_id)
    return ranked_ids


def rank_diagram_encodings(problem_description, initial_state, candidate_encodings):
    """Ask the module-level ``model`` to rank candidate encodings from best to worst.

    Args:
        problem_description: Natural-language description of the domain.
        initial_state: Natural-language description of the initial state.
        candidate_encodings: Verified encodings; they are shown to the model
            numbered from 1.

    Returns:
        ``(ranked_encodings, reasoning)``. ``ranked_encodings`` lists the
        candidates in the model's order. Any candidates the model did not
        mention follow in their original order, so a non-empty input always
        gives a non-empty result. ``reasoning`` is the full model response.
    """
    enumerated_list = "".join(
        f"\nDiagram Encoding {index}:\n{encoding}\n"
        for index, encoding in enumerate(candidate_encodings, start=1)
    )

    prompt_parts = [f"""
    Given the following problem description:
    {problem_description}

    where the initial state is:
    {initial_state}

    We generated the following candidate diagram encodings for this initial state.
    We want to rank them from best to worst, focusing on:
    - How accurately and intuitively they capture the scene geometry.
    - How well they describe relationships between objects.
    - How consistent they are with the problem statement.
    - Whether the shape, position, relative size, and status of all objects is described clearly.

    {enumerated_list}

    Please provide a ranking in the format:

    ```ranking
    <order of best to worst, listing them by their index [1..N], separated by commas>
    ```

    Also, provide a short explanation for your ranking in plain text above that code block.
    """]

    response = inference.get_model_response(prompt_parts, model)
    reasoning = response  # full response, returned for logging by the caller
    ranking = inference.extract_content(response, "ranking")

    num_candidates = len(candidate_encodings)
    ranked_ids = _parse_ranking_ids(ranking, num_candidates)
    # Keep candidates the model omitted (or an unparseable ranking) instead of
    # silently dropping them; they go after the explicitly ranked ones.
    mentioned = set(ranked_ids)
    ranked_ids += [i for i in range(1, num_candidates + 1) if i not in mentioned]

    ranked_encodings = [candidate_encodings[i - 1] for i in ranked_ids]
    return ranked_encodings, reasoning

def generate_diagram_encoding(problem_description, initial_state, prev_section, temp):
    """Sample one candidate diagram encoding for the initial state.

    Builds a few-shot prompt with a worked elevator-domain example, sends it
    to the module-level ``model`` with temperature ``temp``, and extracts the
    ```answer_text``` fenced block that the prompt asks the model to return.

    Args:
        problem_description: Natural-language description of the domain.
        initial_state: Natural-language description of the initial state.
        prev_section: Text listing earlier generations. NOTE: currently not
            injected into the prompt; kept for interface compatibility.
        temp: Sampling temperature forwarded to ``inference.get_model_response``.

    Returns:
        What ``inference.extract_content`` yields for the answer block. This is
        presumably the encoding text, or a falsy value when the block is missing
        (TODO: confirm against ``inference``).
    """
    prompt = f"""You are creating a diagram_encoding for the initial state of a problem:

    Target problem description:
    {problem_description}

    Initial State:
    {initial_state}

    A diagram encoding is a collection of descriptions for each object indicating:
    - shape (should be 2 dimenional: e.g. rectangle, circle, etc)
    - relative/absolute position
    - relative/absolute size
    - status (e.g., color, constraints, ignore any information about the costs)
    - text identifier

    We want a thorough, accurate, and intuitive encoding. If an object exists in the state description, it must be included here. 
    here's an exmaple for a random hypothetical problem:
    We have a building with floors labeled 0 through 7 (n0 to n7). Each floor nY is considered “above” any floor nX where X < Y, meaning n0 < n1 < n2 < … < n7 in ascending order. The goal is to move a set of passangers from their current floor to their destination floor using a set of elevators.

    There are three elevators:
    • slow0, a slow elevator currently at floor 3 (n3). It has 0 passengers right now and has a maximum capacity of 3 passengers. It can reach floors 0, 1, 2, 3, and 4 (n0 to n4).  
    • slow1, another slow elevator, currently at floor 7 (n7). It holds 0 passengers at the moment and can also hold up to 3 passengers. It can serve floors 4, 5, 6, and 7 (n4 to n7).  
    • fast0, a fast elevator, is at floor 0 (n0) with 1 passenger on board, passanger p4. It can hold up to 2 passengers and can reach floors 0, 2, 4, and 6 (n0, n2, n4, n6).  

    We have five passengers on specific floors:  
    • p0 is on floor 2 (n2).  
    • p1 is on floor 5 (n5).  
    • p2 is on floor 5 (n5).  
    • p3 is on floor 4 (n4).  
    • p4 is on floor 0 (n0) on board of elevator fast0.  

    An incomplete diagram encoding of the following scene is: (partially provided):
    (text/identifier: floor_0, shape: rectangle, size: large and horizontally long, position: bottom-most in the floor grid, status: contains passenger p4, contains elevator fast0)  
    (text/identifier: floor_1, shape: rectangle, size: large and horizontally long, position: above floor_0, status: empty) 
    ... for all floors
    (text/identifier: slow0, shape: rectangle, size: smaller than each floor, position: inside floor_3 to the right side of floor, status: 0/3 passanegrs)  
    (text/identifier: slow1, shape: rectangle, size: smaller than each floor, position: inside floor_7 to the right side of floor, status: 0/3 passanegrs)  
    (text/identifier: fast0, shape: rectangle, size: smaller than each floor, position: inside floor_0 to the right side of floor, status: 1/2 passanegrs (p4 on board))  
    ... for all elevators
    (text/identifier: p0, shape: rectangle, size: smaller than each elevator, position: inside floor_2 on the left side of floor, status: on floor_2)  
    (text/identifier: p1, shape: rectangle, size: smaller than each elevator, position: inside floor_5 on the left side of floor, status: on floor_5) 
    (text/identifier: p2, shape: rectangle, size: smaller than each elevator, position: inside floor_5 on the left side next to p1's rectangle, status: on floor_5) 
    (text/identifier: p4, shape: rectangle, size: smaller than each elevator, position: inside floor_0, next to elevator fast0 on the right side of floor, status: on board elevator fast0)   
    ... for all passangers
    
    Use clear spatial relationships (e.g., "to the left of X", "above Y") or absolute positions if relevant. 
    Keep it concise but complete.

    Now produce a new, unique diagram encoding for the target problem. Return it in the following format:

    ```answer_text
    < your final answer>
    ```
    """

    # NOTE: `prev_section` is deliberately not appended to the prompt; a
    # "make it meaningfully different from previous generations" section was
    # disabled by the original author.

    # Pass the prompt as a list, matching every other get_model_response call in this module.
    response = inference.get_model_response([prompt], model, temp)
    # The prompt asks for an ```answer_text``` block, so extract that tag.
    diagram_encoding = inference.extract_content(response, "answer_text")
    if not diagram_encoding:
        # Fall back to the tag this function used previously, in case
        # extract_content matches fence labels loosely.
        diagram_encoding = inference.extract_content(response, "text")
    return diagram_encoding

def _is_affirmative(answer):
    """Return True only if a verifier answer means "yes".

    The verifier's answer is extracted text, and any non-empty string is
    truthy, including "no". The value must therefore be parsed rather than
    tested with ``if``. Booleans pass through unchanged, ``None`` means "no",
    and lists or tuples are joined before parsing.
    """
    if isinstance(answer, bool):
        return answer
    if answer is None:
        return False
    if isinstance(answer, (list, tuple)):
        answer = " ".join(str(part) for part in answer)
    # Tolerate surrounding whitespace, quotes, backticks and emphasis markers.
    normalized = str(answer).strip().strip("`'\"*").strip().lower()
    return normalized.startswith("yes")


def get_1shot_diagram_encoding_ini(domain_name):
    """Generate, verify and rank diagram encodings of a domain's initial state.

    Runs at most ``outer_loops * inner_loops`` generation attempts and stops
    early once ``num_samples`` encodings have passed verification. Every attempt
    is saved under ``<domain>/one_shot/ini_diagram_encoding/attempts/``, and
    failed attempts are annotated with the verifier's reason. The top-ranked
    verified encoding is written to
    ``<domain>/one_shot/ini_diagram_encoding/best_candidate.txt``.

    Returns:
        ``(True, None)`` on success, or ``(False, error_message)`` when no
        attempt passed verification.
    """
    problem_description, initial_state = load_domain_info(domain_name)

    # Prepare folder structure
    base_task_dir = os.path.join(domain_name, "one_shot", "ini_diagram_encoding")
    attempts_dir = os.path.join(base_task_dir, "attempts")
    best_dir = base_task_dir
    os.makedirs(attempts_dir, exist_ok=True)
    os.makedirs(best_dir, exist_ok=True)

    # Verified candidates, kept for the final ranking
    valid_encodings = []
    attempt_count = 0

    for _outer_idx in range(PARAMS["outer_loops"]):
        # Reset memory each new outer loop (currently unused by the generator prompt)
        prev_section = ""
        for _inner_idx in range(PARAMS["inner_loops"]):
            # Stop once enough valid samples have been collected
            if len(valid_encodings) >= PARAMS["num_samples"]:
                break
            attempt_count += 1

            new_generation = generate_diagram_encoding(
                problem_description,
                initial_state,
                prev_section,
                PARAMS["temperature"]
            )

            print(f"Generated diagram encoding attempt {attempt_count}")

            attempt_path = os.path.join(attempts_dir, f"attempt_{attempt_count}.txt")

            if not new_generation:
                # Extraction failed; writing None would crash and there is nothing to verify.
                print(f"Attempt {attempt_count} produced no diagram encoding")
                write_file(
                    attempt_path,
                    "VERIFICATION: SKIPPED\nREASON: no diagram encoding could be extracted from the model response\n",
                )
                continue

            write_file(attempt_path, new_generation)

            is_valid, err_msg = verify_diagram_encoding(
                problem_description,
                new_generation,
                initial_state,
                model
            )

            # The raw answer may be the string "no", which is truthy; parse it explicitly.
            if _is_affirmative(is_valid):
                print(f"Attempt {attempt_count} was verified")
                valid_encodings.append(new_generation)
                # Keep track of previous generations in this same outer loop
                prev_section += f"\n- Previous generation:\n{new_generation}\n"
            else:
                print(f"Attempt {attempt_count} failed the verification test: {err_msg}")
                with open(attempt_path, 'a', encoding='utf-8') as f:
                    f.write(f"\nVERIFICATION: FAILED\nREASON: {err_msg}\n")

        if len(valid_encodings) >= PARAMS["num_samples"]:
            break

    if not valid_encodings:
        error = "[ERROR] No valid diagram encodings found; pipeline ended."
        print(error)
        return False, error

    print("Ranking diagram encodings ...")
    ranked, _reasoning = rank_diagram_encodings(problem_description, initial_state, valid_encodings)

    # If the ranking could not be parsed, fall back to the first verified candidate
    # instead of raising IndexError.
    best_candidate = ranked[0] if ranked else valid_encodings[0]
    best_file = os.path.join(best_dir, "best_candidate.txt")
    write_file(best_file, best_candidate)

    print("[SUCCESS] Found at least one valid diagram encoding.")
    print(f"[INFO] Best candidate saved at: {best_file}")
    return True, None

def main():
    """Command-line entry point: ``python one_shot_diagram_encoding.py <domain_name>``.

    Exits with status 1 when the domain argument is missing or when no
    diagram encoding passes verification, so callers and scripts can detect
    failure.
    """
    if len(sys.argv) < 2:
        print("Usage: python one_shot_diagram_encoding.py <domain_name>")
        sys.exit(1)

    domain_name = sys.argv[1]

    print("Started on getting verified and ranked diagram encoding for the initial state")

    success, _error = get_1shot_diagram_encoding_ini(domain_name)
    if not success:
        # The pipeline has already printed the error; signal failure to the shell.
        sys.exit(1)


if __name__ == "__main__":
    main()